{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qLiDBaSrzWu8"
      },
      "source": [
        "# Getting Started with RAG\n",
        "\n",
        "While large language models (LLMs) show powerful capabilities that power advanced use cases, they suffer from issues such as factual inconsistency and hallucination. Retrieval-augmented generation (RAG) is a powerful approach to enrich LLM capabilities and improve their reliability. RAG involves combining LLMs with external knowledge by enriching the prompt context with relevant information that helps accomplish a task.\n",
        "\n",
        "This tutorial shows how to getting started with RAG by leveraging vector store and open-source LLMs. To showcase the power of RAG, this use case will cover building a RAG system that suggests short and easy to read ML paper titles from original ML paper titles. Paper tiles can be too technical for a general audience so using RAG to generate short titles based on previously created short titles can make research paper titles more accessible and used for science communication such as in the form of newsletters or blogs."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UsChkJxn2CSZ"
      },
      "source": [
        "Before getting started, let's first install the libraries we will use:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "9gy2ijb5zWu-"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!pip install chromadb tqdm fireworks-ai python-dotenv pandas\n",
        "!pip install sentence-transformers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-3iYAReMAe1q"
      },
      "source": [
        "Before continuing, you need to obtain a Fireworks API Key to use the Mistral 7B model.\n",
        "\n",
        "Checkout this quick guide to obtain your Fireworks API Key: https://readme.fireworks.ai/docs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "pBSEMYFszWu_"
      },
      "outputs": [],
      "source": [
        "import fireworks.client\n",
        "import os\n",
        "import dotenv\n",
        "import chromadb\n",
        "import json\n",
        "from tqdm.auto import tqdm\n",
        "import pandas as pd\n",
        "import random\n",
        "\n",
        "# you can set envs using Colab secrets\n",
        "dotenv.load_dotenv()\n",
        "\n",
        "fireworks.client.api_key = os.getenv(\"FIREWORKS_API_KEY\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q9v_0IEDtgov"
      },
      "source": [
        "## Getting Started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J8wYyXMizWu_"
      },
      "source": [
        "Let's define a function to get completions from the Fireworks inference platform."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "1hZldHjmzWvA"
      },
      "outputs": [],
      "source": [
        "def get_completion(prompt, model=None, max_tokens=50):\n",
        "\n",
        "    fw_model_dir = \"accounts/fireworks/models/\"\n",
        "\n",
        "    if model is None:\n",
        "        model = fw_model_dir + \"llama-v2-7b\"\n",
        "    else:\n",
        "        model = fw_model_dir + model\n",
        "\n",
        "    completion = fireworks.client.Completion.create(\n",
        "        model=model,\n",
        "        prompt=prompt,\n",
        "        max_tokens=max_tokens,\n",
        "        temperature=0\n",
        "    )\n",
        "\n",
        "    return completion.choices[0].text"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ys59WgrGzWvA"
      },
      "source": [
        "Let's first try the function with a simple prompt:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        },
        "id": "edQeSLODzWvA",
        "outputId": "fb0174c2-1490-424c-98a5-fd954a833d40"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "' Katie and I am a 20 year old student at the University of Leeds. I am currently studying a BA in English Literature and Creative Writing. I have been working as a tutor for over 3 years now and I'"
            ]
          },
          "execution_count": 18,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "get_completion(\"Hello, my name is\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CwDjmi8EzWvB"
      },
      "source": [
        "Now let's test with Mistral-7B-Instruct:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        },
        "id": "O9TwL-2DzWvB",
        "outputId": "affb05c1-6a61-4a31-8dee-f2ac6a769fd9"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "' [Your Name]. I am a [Your Profession/Occupation]. I am writing to [Purpose of Writing].\\n\\nI am writing to [Purpose of Writing] because [Reason for Writing]. I believe that ['"
            ]
          },
          "execution_count": 19,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "mistral_llm = \"mistral-7b-instruct-4k\"\n",
        "\n",
        "get_completion(\"Hello, my name is\", model=mistral_llm)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LZdlBtx-zWvB"
      },
      "source": [
        "The Mistral 7B Instruct model needs to be instructed using special instruction tokens `[INST] <instruction> [/INST]` to get the right behavior. You can find more instructions on how to prompt Mistral 7B Instruct here: https://docs.mistral.ai/llm/mistral-instruct-v0.1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        },
        "id": "ITURzGa9zWvC",
        "outputId": "5ba3d395-9887-438a-d2a8-808c1598cda6"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "\".\\n1. Why don't scientists trust atoms? Because they make up everything!\\n2. Did you hear about the mathematician who’s afraid of negative numbers? He will stop at nothing to avoid them.\""
            ]
          },
          "execution_count": 20,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "mistral_llm = \"mistral-7b-instruct-4k\"\n",
        "\n",
        "get_completion(\"Tell me 2 jokes\", model=mistral_llm)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        },
        "id": "QN6Y2y1GzWvC",
        "outputId": "4181f5e5-cb73-47d4-c1f3-1ea1e2f8eb8b"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "\" Sure, here are two jokes for you:\\n\\n1. Why don't scientists trust atoms? Because they make up everything!\\n2. Why did the tomato turn red? Because it saw the salad dressing!\""
            ]
          },
          "execution_count": 21,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "mistral_llm = \"mistral-7b-instruct-4k\"\n",
        "\n",
        "get_completion(\"[INST]Tell me 2 jokes[/INST]\", model=mistral_llm)"
      ]
    },
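    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two runs above differ only in the `[INST]`/`[/INST]` wrapper, which is easy to forget when it is typed by hand. As a minimal sketch, a small helper can add the tags; `format_instruction` is a hypothetical name introduced here for illustration and is not part of the Fireworks or Mistral tooling:\n",
        "\n",
        "```python\n",
        "def format_instruction(instruction):\n",
        "    # Strip stray whitespace, then wrap the text in the instruction tags\n",
        "    # used by the prompts in this notebook.\n",
        "    return f'[INST]{instruction.strip()}[/INST]'\n",
        "\n",
        "print(format_instruction('Tell me 2 jokes'))\n",
        "```\n",
        "\n",
        "The result can be passed straight to `get_completion`, for example `get_completion(format_instruction('Tell me 2 jokes'), model=mistral_llm)`."
      ]
    },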
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jZHyn0tJzWvC"
      },
      "source": [
        "Now let's try with a more complex prompt that involves instructions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 87
        },
        "id": "dVwL--2kzWvC",
        "outputId": "1e1d37cd-ba24-4b0a-ab2f-62b516d515f1"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "\" Dear John Doe,\\n\\nWe, Tom and Mary, would like to extend our heartfelt gratitude for your attendance at our wedding. It was a pleasure to have you there, and we truly appreciate the effort you made to be a part of our special day.\\n\\nWe were thrilled to learn about your fun fact - climbing Mount Everest is an incredible accomplishment! We hope you had a safe and memorable journey.\\n\\nThank you again for joining us on this special occasion. We hope to stay in touch and catch up on all the amazing things you've been up to.\\n\\nWith love,\\n\\nTom and Mary\""
            ]
          },
          "execution_count": 22,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "prompt = \"\"\"[INST]\n",
        "Given the following wedding guest data, write a very short 3-sentences thank you letter:\n",
        "\n",
        "{\n",
        "  \"name\": \"John Doe\",\n",
        "  \"relationship\": \"Bride's cousin\",\n",
        "  \"hometown\": \"New York, NY\",\n",
        "  \"fun_fact\": \"Climbed Mount Everest in 2020\",\n",
        "  \"attending_with\": \"Sophia Smith\",\n",
        "  \"bride_groom_name\": \"Tom and Mary\"\n",
        "}\n",
        "\n",
        "Use only the data provided in the JSON object above.\n",
        "\n",
        "The senders of the letter is the bride and groom, Tom and Mary.\n",
        "[/INST]\"\"\"\n",
        "\n",
        "get_completion(prompt, model=mistral_llm, max_tokens=150)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9SROezW6zWvD"
      },
      "source": [
        "## RAG Use Case: Generating Short Paper Titles\n",
        "\n",
        "For the RAG use case, we will be using [a dataset](https://github.com/dair-ai/ML-Papers-of-the-Week/tree/main/research) that contains a list of weekly top trending ML papers.\n",
        "\n",
        "The user will provide an original paper title. We will then take that input and then use the dataset to generate a context of short and catchy papers titles that will help generate catchy title for the original input title.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0sw7Uk6qzWvD"
      },
      "source": [
        "### Step 1: Load the Dataset\n",
        "\n",
        "Let's first load the dataset we will use:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "mv1z1LLczWvD"
      },
      "outputs": [],
      "source": [
        "# load dataset from data/ folder to pandas dataframe\n",
        "# dataset contains column names\n",
        "\n",
        "ml_papers = pd.read_csv(\"../data/ml-potw-10232023.csv\", header=0)\n",
        "\n",
        "# remove rows with empty titles or descriptions\n",
        "ml_papers = ml_papers.dropna(subset=[\"Title\", \"Description\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 467
        },
        "id": "ErNNQRPqzWvD",
        "outputId": "817472e8-8b94-4fd6-c207-b26c62d2babd"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>Title</th>\n",
              "      <th>Description</th>\n",
              "      <th>PaperURL</th>\n",
              "      <th>TweetURL</th>\n",
              "      <th>Abstract</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>Llemma</td>\n",
              "      <td>an LLM for mathematics which is based on conti...</td>\n",
              "      <td>https://arxiv.org/abs/2310.10631</td>\n",
              "      <td>https://x.com/zhangir_azerbay/status/171409802...</td>\n",
              "      <td>We present Llemma, a large language model for ...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>LLMs for Software Engineering</td>\n",
              "      <td>a comprehensive survey of LLMs for software en...</td>\n",
              "      <td>https://arxiv.org/abs/2310.03533</td>\n",
              "      <td>https://x.com/omarsar0/status/1713940983199506...</td>\n",
              "      <td>This paper provides a survey of the emerging a...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>Self-RAG</td>\n",
              "      <td>presents a new retrieval-augmented framework t...</td>\n",
              "      <td>https://arxiv.org/abs/2310.11511</td>\n",
              "      <td>https://x.com/AkariAsai/status/171511027707796...</td>\n",
              "      <td>Despite their remarkable capabilities, large l...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>Retrieval-Augmentation for Long-form Question ...</td>\n",
              "      <td>explores retrieval-augmented language models o...</td>\n",
              "      <td>https://arxiv.org/abs/2310.12150</td>\n",
              "      <td>https://x.com/omarsar0/status/1714986431859282...</td>\n",
              "      <td>We present a study of retrieval-augmented lang...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>GenBench</td>\n",
              "      <td>presents a framework for characterizing and un...</td>\n",
              "      <td>https://www.nature.com/articles/s42256-023-007...</td>\n",
              "      <td>https://x.com/AIatMeta/status/1715041427283902...</td>\n",
              "      <td>NaN</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "                                               Title  \\\n",
              "0                                             Llemma   \n",
              "1                      LLMs for Software Engineering   \n",
              "2                                           Self-RAG   \n",
              "3  Retrieval-Augmentation for Long-form Question ...   \n",
              "4                                           GenBench   \n",
              "\n",
              "                                         Description  \\\n",
              "0  an LLM for mathematics which is based on conti...   \n",
              "1  a comprehensive survey of LLMs for software en...   \n",
              "2  presents a new retrieval-augmented framework t...   \n",
              "3  explores retrieval-augmented language models o...   \n",
              "4  presents a framework for characterizing and un...   \n",
              "\n",
              "                                            PaperURL  \\\n",
              "0                   https://arxiv.org/abs/2310.10631   \n",
              "1                   https://arxiv.org/abs/2310.03533   \n",
              "2                   https://arxiv.org/abs/2310.11511   \n",
              "3                   https://arxiv.org/abs/2310.12150   \n",
              "4  https://www.nature.com/articles/s42256-023-007...   \n",
              "\n",
              "                                            TweetURL  \\\n",
              "0  https://x.com/zhangir_azerbay/status/171409802...   \n",
              "1  https://x.com/omarsar0/status/1713940983199506...   \n",
              "2  https://x.com/AkariAsai/status/171511027707796...   \n",
              "3  https://x.com/omarsar0/status/1714986431859282...   \n",
              "4  https://x.com/AIatMeta/status/1715041427283902...   \n",
              "\n",
              "                                            Abstract  \n",
              "0  We present Llemma, a large language model for ...  \n",
              "1  This paper provides a survey of the emerging a...  \n",
              "2  Despite their remarkable capabilities, large l...  \n",
              "3  We present a study of retrieval-augmented lang...  \n",
              "4                                                NaN  "
            ]
          },
          "execution_count": 24,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "ml_papers.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "id": "KzyvzYcNzWvD"
      },
      "outputs": [],
      "source": [
        "# convert dataframe to list of dicts with Title and Description columns only\n",
        "\n",
        "ml_papers_dict = ml_papers.to_dict(orient=\"records\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "F3LUGNHIzWvE",
        "outputId": "3b1aa123-e316-488f-d0a4-0369bb2f75dd"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'Title': 'Llemma',\n",
              " 'Description': 'an LLM for mathematics which is based on continued pretraining from Code Llama on the Proof-Pile-2 dataset; the dataset involves scientific paper, web data containing mathematics, and mathematical code; Llemma outperforms open base models and the unreleased Minerva on the MATH benchmark; the model is released, including dataset and code to replicate experiments.',\n",
              " 'PaperURL': 'https://arxiv.org/abs/2310.10631',\n",
              " 'TweetURL': 'https://x.com/zhangir_azerbay/status/1714098025956864031?s=20',\n",
              " 'Abstract': 'We present Llemma, a large language model for mathematics. We continue pretraining Code Llama on the Proof-Pile-2, a mixture of scientific papers, web data containing mathematics, and mathematical code, yielding Llemma. On the MATH benchmark Llemma outperforms all known open base models, as well as the unreleased Minerva model suite on an equi-parameter basis. Moreover, Llemma is capable of tool use and formal theorem proving without any further finetuning. We openly release all artifacts, including 7 billion and 34 billion parameter models, the Proof-Pile-2, and code to replicate our experiments.'}"
            ]
          },
          "execution_count": 26,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "ml_papers_dict[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WwfW0XoxzWvE"
      },
      "source": [
        "We will be using SentenceTransformer for generating embeddings that we will store to a chroma document store."
      ]
    },
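    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Retrieval from the vector store comes down to comparing a query embedding with the stored title embeddings and returning the closest ones. The sketch below illustrates the idea with hand-made 2-dimensional vectors and cosine similarity. The vectors, `cosine_similarity`, and `top_k` are assumptions made for illustration, and the distance metric Chroma actually applies depends on how the collection is configured.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def cosine_similarity(a, b):\n",
        "    # Dot product divided by the product of the vector lengths\n",
        "    dot = sum(x * y for x, y in zip(a, b))\n",
        "    norm_a = math.sqrt(sum(x * x for x in a))\n",
        "    norm_b = math.sqrt(sum(y * y for y in b))\n",
        "    return dot / (norm_a * norm_b)\n",
        "\n",
        "def top_k(query_vec, doc_vecs, k=2):\n",
        "    # Rank document indices by similarity to the query, best first\n",
        "    scores = [cosine_similarity(query_vec, d) for d in doc_vecs]\n",
        "    ranked = sorted(range(len(doc_vecs)), key=lambda i: scores[i], reverse=True)\n",
        "    return ranked[:k]\n",
        "\n",
        "# Toy 2-dimensional vectors standing in for title embeddings\n",
        "toy_docs = [[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]]\n",
        "print(top_k([1.0, 0.1], toy_docs))\n",
        "```\n",
        "\n",
        "In the notebook itself, the embeddings come from `all-MiniLM-L6-v2` and the nearest-neighbour search is handled by the Chroma collection, so none of this needs to be written by hand."
      ]
    },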
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "id": "1zFDOicHzWvE"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            ".gitattributes: 100%|██████████| 1.18k/1.18k [00:00<00:00, 194kB/s]\n",
            "1_Pooling/config.json: 100%|██████████| 190/190 [00:00<00:00, 204kB/s]\n",
            "README.md: 100%|██████████| 10.6k/10.6k [00:00<00:00, 7.64MB/s]\n",
            "config.json: 100%|██████████| 612/612 [00:00<00:00, 679kB/s]\n",
            "config_sentence_transformers.json: 100%|██████████| 116/116 [00:00<00:00, 94.0kB/s]\n",
            "data_config.json: 100%|██████████| 39.3k/39.3k [00:00<00:00, 7.80MB/s]\n",
            "pytorch_model.bin: 100%|██████████| 90.9M/90.9M [00:03<00:00, 24.3MB/s]\n",
            "sentence_bert_config.json: 100%|██████████| 53.0/53.0 [00:00<00:00, 55.4kB/s]\n",
            "special_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 161kB/s]\n",
            "tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 6.15MB/s]\n",
            "tokenizer_config.json: 100%|██████████| 350/350 [00:00<00:00, 286kB/s]\n",
            "train_script.py: 100%|██████████| 13.2k/13.2k [00:00<00:00, 12.2MB/s]\n",
            "vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 9.15MB/s]\n",
            "modules.json: 100%|██████████| 349/349 [00:00<00:00, 500kB/s]\n"
          ]
        }
      ],
      "source": [
        "from chromadb import Documents, EmbeddingFunction, Embeddings\n",
        "from sentence_transformers import SentenceTransformer\n",
        "embedding_model = SentenceTransformer('all-MiniLM-L6-v2')\n",
        "\n",
        "class MyEmbeddingFunction(EmbeddingFunction):\n",
        "    def __call__(self, input: Documents) -> Embeddings:\n",
        "        batch_embeddings = embedding_model.encode(input)\n",
        "        return batch_embeddings.tolist()\n",
        "\n",
        "embed_fn = MyEmbeddingFunction()\n",
        "\n",
        "# Initialize the chromadb directory, and client.\n",
        "client = chromadb.PersistentClient(path=\"./chromadb\")\n",
        "\n",
        "# create collection\n",
        "collection = client.get_or_create_collection(\n",
        "    name=f\"ml-papers-nov-2023\"\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eu0_-PREzWvE"
      },
      "source": [
        "We will now generate embeddings for batches:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 28,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 49,
          "referenced_widgets": [
            "b9da413d4f84436ab5dc0fd10d237b0a",
            "c269f70baec246288519dbb2517c05c0",
            "b568c3c04efb49acb00e44aadc247735",
            "868c943d8a82435a8c3df6f32a3cc433",
            "51ad257305a0438fbed46c613d2d59fb",
            "cfbba792e6054178b504c2e2bbc23b2f",
            "54c8ecbae313483e82879016cc49bd25",
            "8e06dc7bf2f94d63a69b651f594ecf74",
            "bacb31eccd5c4a2dba61503a658333f2",
            "b9545ef27bf24e0c86d713ae8a3c0d2c",
            "0285ca3156854ca09a252540ad6a43ff"
          ]
        },
        "id": "kUauose2zWvE",
        "outputId": "18b7bf6a-0341-4843-8168-9875d78a6de9"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 9/9 [00:01<00:00,  7.62it/s]\n"
          ]
        }
      ],
      "source": [
        "# Generate embeddings, and index titles in batches\n",
        "batch_size = 50\n",
        "\n",
        "# loop through batches and generated + store embeddings\n",
        "for i in tqdm(range(0, len(ml_papers_dict), batch_size)):\n",
        "\n",
        "    i_end = min(i + batch_size, len(ml_papers_dict))\n",
        "    batch = ml_papers_dict[i : i + batch_size]\n",
        "\n",
        "    # Replace title with \"No Title\" if empty string\n",
        "    batch_titles = [str(paper[\"Title\"]) if str(paper[\"Title\"]) != \"\" else \"No Title\" for paper in batch]\n",
        "    batch_ids = [str(sum(ord(c) + random.randint(1, 10000) for c in paper[\"Title\"])) for paper in batch]\n",
        "    batch_metadata = [dict(url=paper[\"PaperURL\"],\n",
        "                           abstract=paper['Abstract'])\n",
        "                           for paper in batch]\n",
        "\n",
        "    # generate embeddings\n",
        "    batch_embeddings = embedding_model.encode(batch_titles)\n",
        "\n",
        "    # upsert to chromadb\n",
        "    collection.upsert(\n",
        "        ids=batch_ids,\n",
        "        metadatas=batch_metadata,\n",
        "        documents=batch_titles,\n",
        "        embeddings=batch_embeddings.tolist(),\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1xrbURsMzWvF"
      },
      "source": [
        "Now we can test the retriever:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YoDlxtZhzWvF",
        "outputId": "c258fabb-452d-4740-9073-3d3cf7791bb7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[['LLMs for Software Engineering', 'Communicative Agents for Software Development']]\n"
          ]
        }
      ],
      "source": [
        "collection = client.get_or_create_collection(\n",
        "    name=f\"ml-papers-nov-2023\",\n",
        "    embedding_function=embed_fn\n",
        ")\n",
        "\n",
        "retriever_results = collection.query(\n",
        "    query_texts=[\"Software Engineering\"],\n",
        "    n_results=2,\n",
        ")\n",
        "\n",
        "print(retriever_results[\"documents\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NUHeag1XzWvF"
      },
      "source": [
        "Now let's put together our final prompt:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "x_A0VZ8YzWvF",
        "outputId": "2b3074dc-381e-4cc0-9ee8-ea90673e0da9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model Suggestions:\n",
            "\n",
            "1. S3Eval: A Comprehensive Evaluation Suite for Large Language Models\n",
            "2. Synthetic and Scalable Evaluation for Large Language Models\n",
            "3. Systematic Evaluation of Large Language Models with S3Eval\n",
            "4. S3Eval: A Synthetic and Scalable Approach to Language Model Evaluation\n",
            "5. S3Eval: A Synthetic and Scalable Evaluation Suite for Large Language Models\n",
            "\n",
            "\n",
            "\n",
            "Prompt Template:\n",
            "[INST]\n",
            "\n",
            "Your main task is to generate 5 SUGGESTED_TITLES based for the PAPER_TITLE\n",
            "\n",
            "You should mimic a similar style and length as SHORT_TITLES but PLEASE DO NOT include titles from SHORT_TITLES in the SUGGESTED_TITLES, only generate versions of the PAPER_TILE.\n",
            "\n",
            "PAPER_TITLE: S3Eval: A Synthetic, Scalable, Systematic Evaluation Suite for Large Language Models\n",
            "\n",
            "SHORT_TITLES: Pythia: A Suite for Analyzing Large Language Models Across Training and Scaling\n",
            "ChemCrow: Augmenting large-language models with chemistry tools\n",
            "A Survey of Large Language Models\n",
            "LLaMA: Open and Efficient Foundation Language Models\n",
            "SparseGPT: Massive Language Models Can Be Accurately Pruned In One-Shot\n",
            "REPLUG: Retrieval-Augmented Black-Box Language Models\n",
            "LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention\n",
            "Auditing large language models: a three-layered approach\n",
            "Fine-Tuning Language Models with Just Forward Passes\n",
            "DERA: Enhancing Large Language Model Completions with Dialog-Enabled Resolving Agents\n",
            "\n",
            "SUGGESTED_TITLES:\n",
            "\n",
            "[/INST]\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# user query\n",
        "user_query = \"S3Eval: A Synthetic, Scalable, Systematic Evaluation Suite for Large Language Models\"\n",
        "\n",
        "# query for user query\n",
        "results = collection.query(\n",
        "    query_texts=[user_query],\n",
        "    n_results=10,\n",
        ")\n",
        "\n",
        "# concatenate titles into a single string\n",
        "short_titles = '\\n'.join(results['documents'][0])\n",
        "\n",
        "prompt_template = f'''[INST]\n",
        "\n",
        "Your main task is to generate 5 SUGGESTED_TITLES based for the PAPER_TITLE\n",
        "\n",
        "You should mimic a similar style and length as SHORT_TITLES but PLEASE DO NOT include titles from SHORT_TITLES in the SUGGESTED_TITLES, only generate versions of the PAPER_TILE.\n",
        "\n",
        "PAPER_TITLE: {user_query}\n",
        "\n",
        "SHORT_TITLES: {short_titles}\n",
        "\n",
        "SUGGESTED_TITLES:\n",
        "\n",
        "[/INST]\n",
        "'''\n",
        "\n",
        "responses = get_completion(prompt_template, model=mistral_llm, max_tokens=2000)\n",
        "suggested_titles = ''.join([str(r) for r in responses])\n",
        "\n",
        "# Print the suggestions.\n",
        "print(\"Model Suggestions:\")\n",
        "print(suggested_titles)\n",
        "print(\"\\n\\n\\nPrompt Template:\")\n",
        "print(prompt_template)"
      ]
    },
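    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To reuse the prompt for other titles, the assembly step above can be wrapped in a function. This is a sketch under assumptions: `build_title_prompt` and its parameters are hypothetical names, and the wording is a lightly tidied variant of the template above, so outputs may differ from the recorded run.\n",
        "\n",
        "```python\n",
        "def build_title_prompt(paper_title, short_titles, n_suggestions=5):\n",
        "    # Assemble the prompt line by line; retrieved titles go one per line\n",
        "    lines = [\n",
        "        '[INST]',\n",
        "        f'Generate {n_suggestions} SUGGESTED_TITLES for the PAPER_TITLE.',\n",
        "        'Mimic the style and length of SHORT_TITLES, but do not copy any of them.',\n",
        "        '',\n",
        "        f'PAPER_TITLE: {paper_title}',\n",
        "        '',\n",
        "        'SHORT_TITLES:',\n",
        "        *short_titles,\n",
        "        '',\n",
        "        'SUGGESTED_TITLES:',\n",
        "        '[/INST]',\n",
        "    ]\n",
        "    return '\\n'.join(lines)\n",
        "\n",
        "print(build_title_prompt('A Long Paper Title', ['Short One', 'Short Two']))\n",
        "```\n",
        "\n",
        "With the retriever results from this notebook, the call would look like `build_title_prompt(user_query, results['documents'][0])`, and the returned string can be passed to `get_completion`."
      ]
    },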
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cSAJcQ6Y2cNt"
      },
      "source": [
        "As you can see, the short titles generated by the LLM are somewhat okay. This use case still needs a lot more work and could potentially benefit from finetuning as well. For the purpose of this tutorial, we have provided a simple application of RAG using open-source models from Firework's blazing-fast models.\n",
        "\n",
        "Try out other open-source models here: https://app.fireworks.ai/models\n",
        "\n",
        "Read more about the Fireworks APIs here: https://readme.fireworks.ai/reference/createchatcompletion\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "rag",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.18"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "0285ca3156854ca09a252540ad6a43ff": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "51ad257305a0438fbed46c613d2d59fb": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "54c8ecbae313483e82879016cc49bd25": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "868c943d8a82435a8c3df6f32a3cc433": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b9545ef27bf24e0c86d713ae8a3c0d2c",
            "placeholder": "​",
            "style": "IPY_MODEL_0285ca3156854ca09a252540ad6a43ff",
            "value": " 9/9 [00:04&lt;00:00,  2.22it/s]"
          }
        },
        "8e06dc7bf2f94d63a69b651f594ecf74": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b568c3c04efb49acb00e44aadc247735": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8e06dc7bf2f94d63a69b651f594ecf74",
            "max": 9,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_bacb31eccd5c4a2dba61503a658333f2",
            "value": 9
          }
        },
        "b9545ef27bf24e0c86d713ae8a3c0d2c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b9da413d4f84436ab5dc0fd10d237b0a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_c269f70baec246288519dbb2517c05c0",
              "IPY_MODEL_b568c3c04efb49acb00e44aadc247735",
              "IPY_MODEL_868c943d8a82435a8c3df6f32a3cc433"
            ],
            "layout": "IPY_MODEL_51ad257305a0438fbed46c613d2d59fb"
          }
        },
        "bacb31eccd5c4a2dba61503a658333f2": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "c269f70baec246288519dbb2517c05c0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_cfbba792e6054178b504c2e2bbc23b2f",
            "placeholder": "​",
            "style": "IPY_MODEL_54c8ecbae313483e82879016cc49bd25",
            "value": "100%"
          }
        },
        "cfbba792e6054178b504c2e2bbc23b2f": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
